import matplotlib.pyplot as plt
from teacher_ import teacher_history_vali
from student_teacher import student_history_vali
from student_noteacher import student_no_teacher_history_vali
# Number of epochs to plot. Each validation history must contain at least
# this many entries, otherwise the indexing below raises IndexError.
epochs = 20
x = list(range(1, epochs+1))

# (validation history, legend label) for every model being compared.
# Each history entry is indexed as [0] -> loss and [1] -> accuracy, which is
# how the "Test loss" / "Test accuracy" panels below consume them.
_CURVES = [
    (teacher_history_vali, 'teacher'),
    (student_history_vali, 'student with KD'),
    (student_no_teacher_history_vali, 'student without KD'),
]


def _plot_panel(position, metric_index, title):
    """Draw one row of the 2x1 figure comparing all models on one metric.

    position: 1-based subplot index within the 2x1 grid.
    metric_index: index into each history entry (0 = loss, 1 = accuracy).
    title: panel title.
    """
    plt.subplot(2, 1, position)
    for history, label in _CURVES:
        plt.plot(x, [history[i][metric_index] for i in range(epochs)], label=label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.legend()


_plot_panel(1, 1, 'Test accuracy')
_plot_panel(2, 0, 'Test loss')

# Keep the top panel's axis label from overlapping the bottom panel's title.
plt.tight_layout()
# Without this, running the file as a script produces no visible figure.
plt.show()